import torch
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader

from malwareAnalysis.utils.Process import process
from malwareAnalysis.api import InputDataSet
from malwareAnalysis.module.CnnModule import DenseNet_BC

from malwareAnalysis.core.KmeansPlus import kmeansplus


def IdentifyData(moduleFrom, dataFrom, resultTo):
    """Classify CSV samples with a trained DenseNet_BC and cluster the unknown ones.

    Loads DenseNet_BC weights from *moduleFrom*, runs the model over the dataset
    built from *dataFrom* via ``process``, and collects every CSV row whose
    returned label equals 404 (the marker this function treats as an unknown
    protocol). If any such rows exist, they are passed to ``kmeansplus`` and the
    clustering result is printed.

    Args:
        moduleFrom: Path to a state dict for ``DenseNet_BC``, read with ``torch.load``.
        dataFrom: Path to the CSV data file; it is read both by
            ``InputDataSet.inputDataSet`` and by ``pd.read_csv``.
        resultTo: Currently unused. NOTE(review): presumably meant as the output
            destination for the results - confirm with callers.

    Returns:
        None. Results are only printed.
    """
    # Label value that marks a sample the classifier could not identify.
    unknown_label = 404

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cnn = DenseNet_BC().to(device)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    # load_state_dict returns a report of missing/unexpected keys, not a model,
    # so its return value is not kept.
    cnn.load_state_dict(torch.load(moduleFrom, map_location=device))
    # Switch to inference mode (matters for dropout/batch-norm layers, if the
    # network has them); harmless if `process` also does this.
    cnn.eval()

    dataset = InputDataSet.inputDataSet(dataFrom)
    # shuffle must stay False: label_list[i] is matched to CSV row i below, which
    # only holds if the loader yields samples in file order.
    dataLoader = DataLoader(dataset, batch_size=128, shuffle=False)
    label_list = process(cnn, dataLoader, 0.98)

    inputData = pd.read_csv(dataFrom, index_col=None)

    # CSV positions of the samples labelled unknown; row k of the clustering
    # input below comes from CSV row unknown_indices[k].
    unknown_indices = [
        i for i, label in enumerate(label_list) if label == unknown_label
    ]
    # Select all unknown rows in one step (DataFrame.append was removed in
    # pandas 2.0 and was quadratic when called once per row).
    unknown_data = inputData.iloc[unknown_indices].reset_index(drop=True)

    # Put the unknown-protocol samples into a new data set and analyse them
    # with kmeansplus.
    if unknown_indices:
        print("kmeans......")
        dataSet = unknown_data.to_numpy()
        output = kmeansplus(dataSet)
        print(output)